# This class implements fast inference for the DN-LR model by precomputing the parts that remain constant.
# We are able to perform these pre-computations because of the properties of DN and Gibbs sampling
import copy

import numpy as np
from sklearn.metrics import accuracy_score
from numba import njit


# Parts to train the network
@njit
def _sigmoid(x):
    """
    Apply the logistic sigmoid to every element of a sequence of logits

    Each element of ``x`` is passed through ``_sigmoid_function`` and the results
    are stored, in order, in a newly allocated array.
    NOTE(review): assumes ``x`` is one-dimensional (a vector of logits) - confirm against callers.
    """
    count = len(x)
    activations = np.empty(count, dtype=np.float64)
    for position in range(count):
        activations[position] = _sigmoid_function(x[position])
    return activations


@njit
def _sigmoid_function(x):
    """
    Logistic sigmoid of a single value

    The exponential is always taken of a non-positive number, so it stays
    within (0, 1] for any finite input instead of growing with the magnitude of ``x``.
    """
    # exp(-|x|) equals exp(-x) for x >= 0 and exp(x) for x < 0
    decay = np.exp(-abs(x))
    if x < 0:
        return decay / (1 + decay)
    return 1 / (1 + decay)


@njit
def compute_loss(y_true, y_pred, l1_weight, weights, bias):
    """
    Compute the mean binary cross entropy plus an L1 penalty on the parameters

    :param y_true: array of target labels (0 or 1), with as many elements as ``y_pred``
    :param y_pred: array of predicted probabilities
    :param l1_weight: multiplier applied to the L1 penalty
    :param weights: array of model weights
    :param bias: scalar bias term
    :return: scalar value of the regularised loss
    """
    # Flatten both inputs so that an (n, 1) label column lines up element-wise with
    # (n,) predictions instead of broadcasting to an (n, n) matrix
    labels = y_true.ravel()
    probabilities = y_pred.ravel()
    # binary cross entropy
    # Adding a small value to remove the log(0) error
    positive_term = labels * np.log(probabilities + 1e-9)
    negative_term = (1 - labels) * np.log(1 - probabilities + 1e-9)
    mean_cross_entropy = -np.mean(positive_term + negative_term)
    # The penalty uses |bias| (not the raw bias) so that a negative bias cannot reduce the loss
    l1_penalty = l1_weight * (np.sum(np.abs(weights)) + abs(bias))
    return mean_cross_entropy + l1_penalty


@njit
def compute_gradients(x, y_true, y_pred):
    """
    Function to compute the gradients of the mean binary cross entropy for the LR model

    :param x: 2-D input array with one row per example
    :param y_true: target labels, expected with shape (n, 1) so that they line up with the
        reshaped predictions
    :param y_pred: predicted probabilities, n values
    :return: tuple ``(gradients_w, gradient_b)`` - a 1-D array with one gradient per column
        of ``x`` and the scalar gradient for the bias
    """
    # derivative of binary cross entropy
    n = x.shape[0]
    y_pred = y_pred.reshape((n, 1))
    difference = (np.subtract(y_pred, y_true)).reshape((n, 1))
    gradient_b = np.mean(difference)
    summed_over_examples = x.transpose().astype(np.float64) @ difference.astype(np.float64)
    # The matrix product sums the per-example contributions; divide by n so that the weight
    # gradient is an average, on the same scale as the bias gradient
    gradients_w = (summed_over_examples / n).ravel()
    return gradients_w, gradient_b


@njit
def l1_loss_grad(w):
    """
    Evaluate ``d_abs`` on each entry of ``w`` and return the results as an array
    """
    count = len(w)
    signs = np.empty(count, dtype=np.float64)
    for idx in range(count):
        signs[idx] = d_abs(w[idx])
    return signs


@njit
def d_abs(x):
    """
    Derivative of the absolute value used for the L1 term

    Gives 1.0 where ``x >= 0`` (zero included) and -1.0 where ``x < 0``.
    """
    non_negative = (x >= 0) * 1.0
    negative = (x < 0) * 1.0
    return non_negative - negative


@njit
def update_model_parameters(grad_w, grad_b, weights, bias, learning_rate):
    """
    Take one gradient-descent step on the weights and the bias

    Each parameter is moved against its gradient by ``learning_rate`` times that
    gradient; the pair (updated weights, updated bias) is returned.
    """
    step_w = learning_rate * grad_w
    step_b = learning_rate * grad_b
    return weights - step_w, bias - step_b


@njit
def _fit_for_one_iteration(x, y, weights, bias, learning_rate, l1_weight):
    """
    Run a single gradient-descent update of the model on the data (x, y)

    Returns the loss evaluated with the incoming parameters, followed by the
    updated weights and the updated bias.
    """
    # Forward pass with the current parameters
    probabilities = predict_probs(x, weights, bias)
    loss = compute_loss(y, probabilities, l1_weight, weights, bias)
    data_grad_w, data_grad_b = compute_gradients(x, y, probabilities)
    # Add the l1 components
    total_grad_w = data_grad_w + l1_weight * l1_loss_grad(weights)
    total_grad_b = data_grad_b + l1_weight * d_abs(bias)
    # Gradient-descent step
    new_weights = weights - learning_rate * total_grad_w
    new_bias = bias - learning_rate * total_grad_b
    return loss, new_weights, new_bias


# For Predictions
@njit
def predict_classes(x, weights, bias):
    """
    Predict a 0/1 class for every row of x with the given LR model

    A row is assigned class 1 when its predicted probability is strictly
    greater than 0.5, and class 0 otherwise.
    """
    probabilities = predict_probs(x, weights, bias)
    count = len(probabilities)
    labels = np.zeros(count, dtype=np.intp)
    for idx in range(count):
        if probabilities[idx] > 0.5:
            labels[idx] = 1
    return labels


@njit
def predict_probs(x, weights, bias):
    """
    Return the LR model output (sigmoid of the linear score) for the given input x
    """
    inputs = x.astype(np.float64)
    coefficients = weights.astype(np.float64)
    logits = np.dot(inputs, coefficients) + bias
    return _sigmoid(logits)


@njit
def accuracy_score(true_values, predictions):
    """
    Fraction of entries where the prediction equals the true value

    The number of matching entries is divided by the length of the first axis of ``true_values``.
    """
    matches = true_values == predictions
    total = true_values.shape[0]
    return matches.sum() / total


@njit
def pre_compute_evidence_part_for_one_example(evidence_part, weights, bias, num_evidence_nodes):
    """
    Pre-compute the contribution of the evidence part to the logit

    The weights from index ``num_evidence_nodes - 1`` onwards are multiplied with
    ``evidence_part`` and the bias is added to the result.
    """
    split_index = num_evidence_nodes - 1
    evidence_weights = weights[split_index:].astype(np.float64)
    evidence_values = evidence_part.astype(np.float64)
    return np.dot(evidence_values, evidence_weights) + bias


@njit
def compute_true_label_part_for_one_example(true_label_values, weights, num_evidence_nodes):
    """
    Compute the contribution of the true label part to the logit

    The weights before index ``num_evidence_nodes - 1`` are multiplied with
    ``true_label_values``; no bias is added here.
    """
    split_index = num_evidence_nodes - 1
    label_weights = weights[:split_index].astype(np.float64)
    label_values = true_label_values.astype(np.float64)
    return np.dot(label_values, label_weights)


class LogisticRegression:
    """
    Logistic regression with an L1 penalty, trained by full-batch or mini-batch gradient descent

    The numerical work is delegated to the module-level compiled helpers; this class keeps
    the data, the hyper-parameters, the learned parameters and the per-epoch metrics.
    """

    def __init__(self, _x_train, _y_train, _x_val, _y_val, learning_rate, stochastic, epochs, batch_size, l1_weight):
        """
        Store the data and the hyper-parameters; no training happens here

        :param _x_train: training inputs
        :param _y_train: training labels
        :param _x_val: validation inputs; must expose ``shape`` (its second dimension is read here)
        :param _y_val: validation labels
        :param learning_rate: step size of the gradient-descent update
        :param stochastic: if truthy, train on randomly sampled mini-batches, otherwise on all data
        :param epochs: number of training epochs
        :param batch_size: number of examples drawn per mini-batch (used only when ``stochastic``)
        :param l1_weight: multiplier of the L1 penalty
        """
        # Per-epoch history, appended to by fit()
        self.losses = []
        self.val_accuracies = []
        self.train_accuracies = []
        # Learned parameters, initialised in fit()
        self.weights = None
        self.bias = None
        self.learning_rate = learning_rate
        self.stochastic = stochastic
        self.batch_size = batch_size
        self.epochs = epochs
        self.num_iter_in_epoch = None
        self._x_train = _x_train
        self._y_train = _y_train
        self._x_val = _x_val
        self._y_val = _y_val
        self.l1_weight = l1_weight
        # Index used to split the weight vector in the two *_part_for_one_example methods
        self.num_evidence_nodes = self._x_val.shape[1] // 2 + 1

    @staticmethod
    def _is_report_epoch(epoch, epochs):
        """
        Tell whether progress should be printed after the zero-based ``epoch``

        Progress is reported roughly ten times per run; the interval is clamped to at least
        one so that runs with fewer than ten epochs do not divide by zero.
        """
        interval = max(1, epochs // 10)
        return (epoch + 1) % interval == 0

    def fit(self):
        """
        Train the model on the training data given to the constructor

        The weights and the bias are initialised randomly, then updated for ``epochs`` epochs.
        After every epoch the training loss and the training / validation accuracies are
        appended to ``losses``, ``train_accuracies`` and ``val_accuracies``.
        """
        x_ = self._transform_x(self._x_train)
        y_ = self._transform_y(self._y_train)
        epochs = self.epochs
        nexamples = x_.shape[0]
        if self.stochastic:
            # At least one update per epoch, even when the batch is larger than the data set
            self.num_iter_in_epoch = max(1, nexamples // self.batch_size)
        else:
            self.num_iter_in_epoch = 1
        self.weights = np.random.rand(x_.shape[1], ).astype(dtype=np.float64)
        self.bias = np.random.rand()
        for epoch in range(epochs):
            for _ in range(self.num_iter_in_epoch):
                if self.stochastic:
                    # Mini-batch sampled with replacement
                    indices = np.random.randint(nexamples, size=self.batch_size)
                    x = x_[indices, :]
                    y = y_[indices, :]
                else:
                    x = x_
                    y = y_
                # The per-batch loss is not recorded; the epoch loss is recomputed below
                _, self.weights, self.bias = _fit_for_one_iteration(x, y, self.weights, self.bias,
                                                                    self.learning_rate, self.l1_weight)
            # Metrics are evaluated on the arrays exactly as they were passed to the constructor
            pred_to_class_val = predict_classes(self._x_val, self.weights, self.bias)
            pred_to_class_train = predict_classes(self._x_train, self.weights, self.bias)
            probs_train = predict_probs(self._x_train, self.weights, self.bias)
            train_loss = compute_loss(self._y_train, probs_train, self.l1_weight, self.weights, self.bias)
            val_accuracy = accuracy_score(self._y_val, pred_to_class_val)
            train_accuracy = accuracy_score(self._y_train, pred_to_class_train)
            self.val_accuracies.append(val_accuracy)
            self.train_accuracies.append(train_accuracy)
            self.losses.append(train_loss)
            if self._is_report_epoch(epoch, epochs):
                print(f"Epoch:{str(epoch)}" + "\n Validation Accuracy: " + str(val_accuracy))
                print(f"Train -  Loss: {str(train_loss)} Accuracy: {str(train_accuracy)}")

    def truncate_weights(self, threshold):
        """
        Set to zero, in place, every weight whose absolute value is below ``threshold``
        """
        self.weights[np.abs(self.weights) < threshold] = 0

    def _transform_x(self, x):
        """
        Return an independent float64 array built from ``x``
        """
        x = copy.deepcopy(x)
        return np.array(x, dtype=np.float64)

    def _transform_y(self, y):
        """
        Return an independent float64 column (shape ``(n, 1)``) built from ``y``
        """
        y = copy.deepcopy(y)
        return np.array(y.reshape(y.shape[0], 1), dtype=np.float64)

    def pre_compute_evidence_part_for_one_example(self, evidence_part):
        """
        Evidence contribution to the logit, using this model's weights, bias and split index
        """
        return pre_compute_evidence_part_for_one_example(evidence_part, self.weights, self.bias,
                                                         self.num_evidence_nodes)

    def compute_true_label_part_for_one_example(self, true_label_values):
        """
        True-label contribution to the logit, using this model's weights and split index
        """
        return compute_true_label_part_for_one_example(true_label_values, self.weights, self.num_evidence_nodes)

    def sigmoid(self, x):
        """
        Sigmoid of a scalar, or element-wise sigmoid of a list / tuple / array

        :param x: a single number or a one-dimensional sequence of numbers
        :return: a float for scalar input, otherwise an array with one value per element
        """
        # The type check is on the argument; zero-dimensional arrays take the scalar path
        if isinstance(x, (list, tuple, np.ndarray)) and np.ndim(x) > 0:
            return np.array([_sigmoid_function(value) for value in x])
        return _sigmoid_function(x)
